"""
This script runs real-time detection on frames captured from the default webcam
(device index 0) and displays the annotated video until 'q' is pressed.
"""
import argparse
import os
import sys

import cv2
import time

from vision.ssd.config.fd_config import define_img_size

def _build_parser():
    """Return the command-line parser for this script.

    Every option is declared as a (flag, default, type, help) row so the
    whole CLI surface can be read in one place.
    """
    cli = argparse.ArgumentParser(description='detect_imgs')
    option_rows = (
        ('--net_type', "RFB", str,
         'The network architecture ,optional: RFB (higher precision) or slim (faster)'),
        ('--input_size', 640, int,
         'define network input size,default optional value 128/160/320/480/640/1280'),
        ('--threshold', 0.6, float,
         'score threshold'),
        ('--candidate_size', 1500, int,
         'nms candidate size'),
        ('--test_device', "cuda:0", str,
         'cuda:0 or cpu'),
    )
    for flag, default_value, value_type, help_text in option_rows:
        cli.add_argument(flag, default=default_value, type=value_type, help=help_text)
    return cli


parser = _build_parser()
args = parser.parse_args()

# The input size has to be registered before the model-construction helpers
# are imported, so this call must stay ahead of the `vision.ssd` model import
# that follows it.
define_img_size(args.input_size)

from vision.ssd.mb_tiny_RFB_fd import create_Mb_Tiny_RFB_fd, create_Mb_Tiny_RFB_fd_predictor

# Output directory name; nothing in the visible part of this script reads it.
result_path = "./detect_imgs_results"
# Text file with one class name per line.
label_path = "./models/voc-model-labels.txt"
# Device string taken from --test_device (e.g. "cuda:0" or "cpu").
test_device = args.test_device

# Read the class names inside a context manager so the file handle is closed
# deterministically instead of being left open for the life of the process.
with open(label_path) as label_file:
    class_names = [line.strip() for line in label_file]

# Path of the weights file loaded into the network below.
model_path = "models/pretrained/version-RFB-640.pth"

# Build the network in test mode with one output per class name, then wrap it
# in a predictor that uses the --candidate_size value.
net = create_Mb_Tiny_RFB_fd(len(class_names), is_test=True, device=test_device)
predictor = create_Mb_Tiny_RFB_fd_predictor(net, candidate_size=args.candidate_size, device=test_device)

# Load the weights from model_path into the constructed network.
net.load(model_path)

# Open the default camera (usually device index 0).
cap = cv2.VideoCapture(0)

# Stop early if the camera could not be opened. sys.exit() is used rather than
# the exit() helper, which is only injected by the `site` module and is absent
# in some environments; the non-zero status reports the failure to the caller.
if not cap.isOpened():
    print("无法打开摄像头")
    sys.exit(1)

# Label lookup kept from the original script; nothing in the visible code
# reads it (the overlay text below comes from class_names).
name = {"0":"BACKGROUND","1":"face"}

# Minimum score a detection needs before it is drawn on the frame. This is
# applied in addition to args.threshold, which is passed to the predictor.
DRAW_THRESHOLD = 0.9

try:
    while True:
        # Read one frame; ret is False when no frame could be obtained.
        ret, frame = cap.read()
        if not ret:
            print("无法获取帧")
            break

        # Convert the captured BGR frame to RGB before running the predictor.
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Time only the predictor call and report it in milliseconds.
        start_time = time.time()
        boxes, labels, probs = predictor.predict(image, args.candidate_size / 2, args.threshold)
        elapsed_ms = (time.time() - start_time) * 1000
        print("模型检测时间：{}ms".format(elapsed_ms))

        for i in range(boxes.size(0)):
            # Skip low-score detections entirely: neither a rectangle nor a
            # label is drawn for them.
            if probs[i] < DRAW_THRESHOLD:
                continue

            box = boxes[i, :]
            top_left = (int(box[0]), int(box[1]))
            bottom_right = (int(box[2]), int(box[3]))
            cv2.rectangle(frame, top_left, bottom_right, (0, 0, 255), 2)

            # Caption each box with the class of *this* detection, placed
            # 10 px above the box's top-left corner.
            cv2.putText(frame, class_names[labels[i]], (top_left[0], top_left[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

        # Show the annotated frame.
        cv2.imshow('实时视频', frame)

        # Leave the loop when 'q' is pressed.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the camera and close the OpenCV windows, including when
    # the loop is interrupted by an exception such as KeyboardInterrupt.
    cap.release()
    cv2.destroyAllWindows()





